#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2016, Fabian Greif
# Copyright (c) 2017, 2024, Niklas Hauser
#
# This file is part of the modm project.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# -----------------------------------------------------------------------------

import re
from pathlib import Path
from collections import defaultdict

def getDefineForDevice(device_id, familyDefines):
    """
    Returns the STM32 specific define from an identifier

    :param device_id: device identifier providing the string attributes
                      `family` and `name`, plus `size`, `pin` and `variant`
                      which are only read when more than one define matches.
    :param familyDefines: iterable of all define names found for the family.
    :returns: the best matching define name, or `None` if no define matches.
    """
    # get all defines for this device name
    devName = "STM32{}{}".format(device_id.family.upper(), device_id.name.upper())

    # Map STM32F7x8 -> STM32F7x7
    if device_id.family == "f7" and devName[8] == "8":
        devName = devName[:8] + "7"

    deviceDefines = sorted(define for define in familyDefines if define.startswith(devName))
    # without any candidate there is nothing to choose from: report "not found"
    # to the caller instead of letting `min()` fail on an empty sequence below
    if not deviceDefines:
        return None
    # if there is only one define that's the one
    if len(deviceDefines) == 1:
        return deviceDefines[0]

    # sort with respecting variants
    minlen = min(len(d) for d in deviceDefines)
    deviceDefines.sort(key=lambda d: (d[:minlen], d[minlen:]))

    # now we match for the size-id (and variant-id if applicable).
    if device_id.family == "h7":
        devNameMatch = devName + "xx"
        if device_id.variant:
            devNameMatch += device_id.variant.upper()
    else:
        devNameMatch = devName + "x{}".format(device_id.size.upper())
        if device_id.family == "l1":
            # Map STM32L1xxQC and STM32L1xxZC -> STM32L162QCxA variants
            if device_id.pin in ["q", "z"] and device_id.size == "c":
                devNameMatch += "A"
            else:
                devNameMatch += device_id.variant.upper()

    # the first define that sorts at or after the wanted name is the match
    define = next((d for d in deviceDefines if devNameMatch <= d), None)
    if define is not None:
        return define

    # now we match for the pin-id.
    devNameMatch = devName + "{}x".format(device_id.pin.upper())
    return next((d for d in deviceDefines if devNameMatch <= d), None)

# Module-wide build properties: filled in by the query factories in this file
# (`common_header_file`, `common_rcc_map`, `common_peripherals`) and read back
# by `build()`, which also uses this dictionary as the template substitutions.
bprops = {}
def common_rcc_map(env):
    """
    Collects the CMSIS bit fields that enable and reset peripherals in the RCC.
    The bit fields are named `RCC_(REGISTER)_(PERIPHERAL)_(TYPE)` where:

      - REGISTER: a variation of `(BUS)(ID?)(ENR|RSTR)`, e.g. `AHB1ENR`
      - PERIPHERAL: typical peripheral name, e.g. `GPIOA`
      - TYPE: either `EN` or `RST`.

    As a side effect the sorted list of `(name, type)` tuples of all peripheral
    definitions found in the headers is stored in `bprops["peripherals"]`.

    :returns: a 2D-dictionary: `map[PERIPHERAL][TYPE] = REGISTER`
    """
    headers = env.query("headers")
    header_paths = (
        repopath("ext/arm/cmsis/CMSIS/Core/Include", headers["core_header"]),
        localpath(bprops["folder"], headers["device_header"]),
    )
    # scan the core header and the device header as one piece of text
    content = "".join(Path(header_path).read_text(encoding="utf-8", errors="replace")
                      for header_path in header_paths)

    # `__MPU_PRESENT` / `__FPU_PRESENT` tell whether the core has an MPU / FPU
    core_features = {
        name: bool(int(present))
        for name, present in re.findall(r"#define +__([MF]PU)_PRESENT +([01])", content)
    }

    # every `#define NAME ((SomeType_Type[Def]` is a peripheral definition;
    # MPU and FPU are dropped when the core reports them as not present
    peripherals = [
        (name, type_name)
        for name, type_name in re.findall(r"#define +(.*?) +\(\((.*?_Type(?:Def)?)", content)
        if core_features.get(name, True)
    ]
    peripherals.sort(key=lambda peripheral: peripheral[0])
    bprops["peripherals"] = peripherals

    # Find all RCC enable and reset definitions
    rcc_map = defaultdict(dict)
    for register, peripheral, kind in re.findall(r"RCC_([A-Z0-9]*?)_([A-Z0-9]+?)(EN|RST) ", content):
        rcc_map[peripheral][kind] = register
    return rcc_map


def common_header_file(env):
    """
    Gives information about the STM32 header files. For example the STM32F469:

      - family_define: `STM32F4`
      - define: `STM32F469xx`
      - core_header: `core_cm4.h`
      - system_header: `system_stm32f4xx.h`
      - family_header: `stm32f4xx.h`
      - device_header: `stm32f469xx.h`

    As a side effect these values and the header folder (key `folder`) are
    stored in `bprops`.

    :returns: a dictionary with the above keys
    :raises ValidateException: if the family header contains no define that
                               matches the target device.
    """
    device = env[":target"]

    # reduce the core type to the CMSIS header suffix: the `cortex-m` prefix
    # and the letters `f` and `d` are removed, `+` is spelled out as `plus`
    core = device.get_driver("core")["type"]
    core = core.replace("cortex-m", "").replace("+", "plus").replace("f", "").replace("d", "")
    core_header = "core_cm{}.h".format(core)

    folder = "stm32{}xx".format(device.identifier.family)
    family_header = folder + ".h"
    folder = "stm32/{}/Include".format(folder)

    # the family header selects the device header via `#if defined(STM32...)`
    content = Path(localpath(folder, family_header)).read_text(encoding="utf-8", errors="replace")
    match = re.findall(r"if defined\((?P<define>STM32[CFGHLU][\w\d]+)\)", content)
    # `re.findall` returns a (possibly empty) list and never `None`, so the
    # "nothing found" case must be detected before the lookup uses the list
    define = getDefineForDevice(device.identifier, match) if match else None
    if define is None:
        raise ValidateException("No device define found for '{}'!".format(device.partname))
    device_header = define.lower() + ".h"
    family_define = "STM32{}".format(device.identifier.family.upper())

    headers = {
        "define": define,
        "family_define": family_define,
        "core_header": core_header,
        "device_header": device_header,
        "family_header": family_header,
        "system_header": "system_" + family_header,
    }
    bprops.update(headers)
    bprops["folder"] = folder
    return headers


def common_peripherals(env):
    """
    Collects the names of all peripherals of the target device, written in the
    modm naming convention (underscores removed, first letter capitalized).

    The result is also stored in `bprops["all_peripherals"]`.

    :returns: a sorted list of all peripheral names.
    """
    def translate(s):
        return s.replace("_", "").capitalize()

    def get_driver(s):
        # driver name directly followed by the instance, `None` without a driver
        name = translate(s["driver"]) if "driver" in s else None
        if "instance" in s:
            name += translate(s["instance"])
        return name

    device = env[":target"]
    gpio_driver = device.get_driver("gpio")

    names = set()
    # peripherals referenced by the remap-able GPIO signals
    for remap in gpio_driver.get("remap", []):
        names.add(get_driver(remap))
    # peripherals referenced by the normal GPIO signals
    for gpio in gpio_driver["gpio"]:
        for signal in gpio.get("signal", []):
            names.add(get_driver(signal))
    # peripherals from the driver instances
    for driver in device._properties["driver"]:
        if driver["name"] in ("gpio", "core"):
            continue
        driver_name = translate(driver["name"])
        if "instance" in driver:
            for instance in driver["instance"]:
                names.add(driver_name + translate(instance))
        else:
            names.add(driver_name)
        # peripherals that are added on top of what the driver itself lists
        if driver_name == "Dma" and driver["type"] == "stm32-mux":
            names.add("Dmamux1")
        if driver_name == "Usbotghs":
            names.add("Usbotghsulpi")

    # signals without a driver produce `None`, which is not a peripheral
    names.discard(None)
    ordered = sorted(names)
    bprops["all_peripherals"] = ordered
    return ordered


# -----------------------------------------------------------------------------
def init(module):
    """
    Sets the name of this module to `:cmsis:device` and takes its description
    from `module.md`.
    """
    # NOTE(review): `FileReader` is not defined in this file; presumably it is
    # provided by the build tool that loads this module - confirm.
    module.name = ":cmsis:device"
    module.description = FileReader("module.md")

def prepare(module, options):
    """
    Makes the module available for STM32 targets only and registers its
    queries (`rcc-map`, `peripherals`, `headers`) and its dependency on
    `:cmsis:core`.

    :returns: `False` if the target platform is not `stm32`, otherwise `True`.
    """
    if options[":target"].identifier["platform"] != "stm32":
        return False

    # query name -> factory function, registered in this order
    queries = (
        ("rcc-map", common_rcc_map),
        ("peripherals", common_peripherals),
        ("headers", common_header_file),
    )
    for query_name, query_factory in queries:
        module.add_query(EnvironmentQuery(name=query_name, factory=query_factory))

    module.depends(":cmsis:core")
    return True

def validate(env):
    """
    Runs the `rcc-map` and `peripherals` queries, in this order, and discards
    their results.
    """
    # NOTE(review): presumably done so that `bprops` is filled before `build()`
    # reads it - confirm against the query caching of the build tool.
    for query_name in ("rcc-map", "peripherals"):
        env.query(query_name)

def build(env):
    """
    Registers the include paths and the family define, copies the device,
    system and family headers and renders the `device.hpp.in` and
    `peripherals.hpp.in` templates.

    Reads the values that the queries of this module stored in `bprops`.
    """
    for include_path in ("modm/ext", "modm/ext/cmsis/device"):
        env.collect(":build:path.include", include_path)
    env.collect(":build:cppdefines", bprops["family_define"])

    # the headers keep their file name in the output folder
    env.outbasepath = "modm/ext/cmsis/device"
    for header_key in ("device_header", "system_header", "family_header"):
        header = bprops[header_key]
        env.copy(localpath(bprops["folder"], header), header)

    # the templates get all build properties plus a few derived entries
    env.substitutions = bprops
    env.substitutions.update({
        "headers": [bprops["device_header"], bprops["system_header"]],
        "defines": [bprops["define"]],
        "target": env[":target"].identifier,
    })

    templates = (
        ("modm/src/modm/platform", "device.hpp.in"),
        ("modm/src/modm/platform/core", "peripherals.hpp.in"),
    )
    for outbasepath, template in templates:
        env.outbasepath = outbasepath
        env.template(template)
